{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ace56453",
   "metadata": {
    "id": "ace56453"
   },
   "source": [
    "![MuJoCo banner](https://raw.githubusercontent.com/google-deepmind/mujoco/main/banner.png)\n",
    "\n",
    "### Copyright notice\n",
    "\n",
    "> <p><small><small>Copyright 2025 DeepMind Technologies Limited.</small></small></p>\n",
    "> <p><small><small>Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at <a href=\"http://www.apache.org/licenses/LICENSE-2.0\">http://www.apache.org/licenses/LICENSE-2.0</a>.</small></small></p>\n",
    "> <p><small><small>Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.</small></small></p>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4be2d8a0",
   "metadata": {},
   "source": [
    "# Tutorial\n",
    "\n",
    "This notebook is the first of two that demonstrate pixels-to-actions training using the experimental Madrona rendering backend.\n",
    "\n",
    "### Usage\n",
    "\n",
    "This is a simplified version of the main cartpole balance [tutorial](https://github.com/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_1.ipynb), adjusted to run on a hosted instance with a T4 GPU."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "574fcff5",
   "metadata": {
    "id": "574fcff5"
   },
   "source": [
    "### Build and Install Madrona MJX (5 min on hosted instance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be30d4a9",
   "metadata": {
    "id": "be30d4a9"
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "cat >setup.bash <<EOF\n",
    "\n",
    "# Ensure JAX matches the CUDA version that will be used to build Madrona MJX.\n",
    "pip uninstall -y jax\n",
    "pip install \"jax[cuda12_local]==0.4.35\"\n",
    "\n",
    "# Install the X11 and Mesa development packages (-y skips the confirmation prompt).\n",
    "sudo apt install -y libx11-dev libxrandr-dev libxinerama-dev libxcursor-dev libxi-dev mesa-common-dev\n",
    "\n",
    "# Prevent Python from trying to import the source packages.\n",
    "mkdir -p modules\n",
    "cd modules\n",
    "\n",
    "# Install madrona-mjx\n",
    "echo \"Cloning madrona-mjx...\"\n",
    "git clone https://github.com/shacklettbp/madrona_mjx.git\n",
    "\n",
    "cd madrona_mjx\n",
    "git submodule update --init --recursive\n",
    "\n",
    "mkdir build\n",
    "cd build\n",
    "cmake -DLOAD_VULKAN=OFF ..\n",
    "make -j 8\n",
    "\n",
    "cd ..\n",
    "pip install -e .\n",
    "EOF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9758b8b6",
   "metadata": {
    "id": "9758b8b6"
   },
   "outputs": [],
   "source": [
    "import subprocess\n",
    "# Run the setup.bash script\n",
    "subprocess.run([\"bash\", \"setup.bash\"], check=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "HizwPjwEEEei",
   "metadata": {
    "id": "HizwPjwEEEei"
   },
   "source": [
    "After installing Madrona MJX, **restart your Colab instance** (`Runtime -> Restart session`) before continuing."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a90aa84",
   "metadata": {
    "id": "1a90aa84"
   },
   "source": [
    "### Install Playground and Prerequisites"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f617d7bc",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "f617d7bc",
    "outputId": "5a27109f-820f-4701-838d-ea3cd257afae"
   },
   "outputs": [],
   "source": [
    "!pip install mujoco\n",
    "!pip install mujoco_mjx\n",
    "!pip install brax\n",
    "!pip install playground\n",
    "!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n",
    "!pip install -q mediapy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e8b1475",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "4e8b1475",
    "outputId": "1120b827-f11f-4296-e655-dd69c873c440"
   },
   "outputs": [],
   "source": [
    "# @title Configuration for both local and Colab instances.\n",
    "\n",
    "# Disable JAX's default GPU memory preallocation so that other GPU users,\n",
    "# such as the rendering backend, are not starved of memory.\n",
    "import os\n",
    "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n",
    "\n",
    "# Check that a GPU is available.\n",
    "import subprocess\n",
    "if subprocess.run('nvidia-smi').returncode:\n",
    "    raise RuntimeError(\n",
    "        'Cannot communicate with GPU. '\n",
    "        'Make sure you are using a GPU Colab runtime. '\n",
    "        'Go to the Runtime menu and select Choose runtime type.'\n",
    "    )\n",
    "\n",
    "# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n",
    "# This is usually installed as part of an Nvidia driver package, but the Colab\n",
    "# kernel doesn't install its driver via APT, and as a result the ICD is missing.\n",
    "# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\n",
    "NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'\n",
    "if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n",
    "    with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:\n",
    "        f.write(\"\"\"{\n",
    "        \"file_format_version\" : \"1.0.0\",\n",
    "        \"ICD\" : {\n",
    "            \"library_path\" : \"libEGL_nvidia.so.0\"\n",
    "        }\n",
    "    }\n",
    "    \"\"\")\n",
    "\n",
    "# Configure MuJoCo to use the EGL rendering backend (requires GPU)\n",
    "print('Setting environment variable to use GPU rendering:')\n",
    "%env MUJOCO_GL=egl\n",
    "\n",
    "try:\n",
    "    print('Checking that the installation succeeded:')\n",
    "    import mujoco\n",
    "\n",
    "    mujoco.MjModel.from_xml_string('<mujoco/>')\n",
    "except Exception as e:\n",
    "    raise e from RuntimeError(\n",
    "        'Something went wrong during installation. Check the shell output above '\n",
    "        'for more information.\\n'\n",
    "        'If using a hosted Colab runtime, make sure you enable GPU acceleration '\n",
    "        'by going to the Runtime menu and selecting \"Choose runtime type\".'\n",
    "    )\n",
    "\n",
    "print('Installation successful.')\n",
    "\n",
    "# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs\n",
    "xla_flags = os.environ.get('XLA_FLAGS', '')\n",
    "xla_flags += ' --xla_gpu_triton_gemm_any=True'\n",
    "os.environ['XLA_FLAGS'] = xla_flags"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ce3f23a",
   "metadata": {
    "id": "7ce3f23a"
   },
   "outputs": [],
   "source": [
    "# @title Import packages for plotting and creating graphics\n",
    "import json\n",
    "import itertools\n",
    "import time\n",
    "from typing import Callable, List, NamedTuple, Optional, Union\n",
    "import numpy as np\n",
    "\n",
    "# Graphics and plotting.\n",
    "import mediapy as media\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# More legible printing from numpy.\n",
    "np.set_printoptions(precision=3, suppress=True, linewidth=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92bcc22b",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "92bcc22b",
    "lines_to_next_cell": 2,
    "outputId": "57b83db5-132f-469b-c008-62e65d759f97"
   },
   "outputs": [],
   "source": [
    "# @title Import MuJoCo, MJX, and Brax\n",
    "from datetime import datetime\n",
    "import functools\n",
    "import os\n",
    "import time\n",
    "\n",
    "from brax.training.agents.ppo import networks_vision as ppo_networks_vision\n",
    "from brax.training.agents.ppo import train as ppo\n",
    "from IPython.display import clear_output\n",
    "import jax\n",
    "from jax import numpy as jp\n",
    "from matplotlib import pyplot as plt\n",
    "import mediapy as media\n",
    "import numpy as np\n",
    "\n",
    "from mujoco_playground import wrapper\n",
    "\n",
    "np.set_printoptions(precision=3, suppress=True, linewidth=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0902bd0d",
   "metadata": {
    "id": "0902bd0d"
   },
   "source": [
    "Now, let's load MuJoCo Playground's JAX Cartpole re-implementation. While we're at it, we'll configure it for the upcoming training. Note that calling `dm_control_suite.load` builds the rendering backend, which takes around **3 minutes** on a T4 instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "743928b0",
   "metadata": {
    "id": "743928b0",
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "from mujoco_playground import dm_control_suite\n",
    "\n",
    "\n",
    "num_envs = 512\n",
    "ctrl_dt = 0.04\n",
    "episode_length = int(3 / ctrl_dt)\n",
    "\n",
    "config_overrides = {\n",
    "    \"vision\": True,\n",
    "    \"vision_config.render_batch_size\": num_envs,\n",
    "    \"action_repeat\": 1,\n",
    "    \"ctrl_dt\": ctrl_dt,\n",
    "    \"episode_length\": episode_length,\n",
    "}\n",
    "\n",
    "env_name = \"CartpoleBalance\"\n",
    "env = dm_control_suite.load(\n",
    "    env_name, config_overrides=config_overrides\n",
    ")\n",
    "\n",
    "env = wrapper.wrap_for_brax_training(\n",
    "    env,\n",
    "    vision=True,\n",
    "    num_vision_envs=num_envs,\n",
    "    action_repeat=1,\n",
    "    episode_length=episode_length,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4d70a08",
   "metadata": {
    "id": "c4d70a08",
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "jit_reset = jax.jit(env.reset)\n",
    "jit_step = jax.jit(env.step)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ab3a297",
   "metadata": {
    "id": "7ab3a297"
   },
   "source": [
    "## Balancing a Cartpole\n",
    "\n",
    "Now, let's train a pixels-to-torque policy to balance a cartpole. In [our implementation](https://github.com/google-deepmind/mujoco_playground/blob/main/mujoco_playground/_src/dm_control_suite/cartpole.py), the observations are grayscale images stacked across sequential timesteps. This keeps the observation dimensionality under control while still giving the agent the temporal information it needs to understand the dynamics."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59f3b3ac",
   "metadata": {
    "id": "59f3b3ac"
   },
   "source": [
    "#### Visualize the environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b1286e8",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 297
    },
    "id": "6b1286e8",
    "outputId": "449f5b20-0f85-400a-e78e-1923ccd5c750"
   },
   "outputs": [],
   "source": [
    "def unvmap(x):\n",
    "  \"\"\"Selects the first environment from a batched pytree.\"\"\"\n",
    "  return jax.tree.map(lambda y: y[0], x)\n",
    "\n",
    "\n",
    "state = jit_reset(jax.random.split(jax.random.PRNGKey(0), num_envs))\n",
    "rollout = [unvmap(state)]\n",
    "\n",
    "# Drive each actuator with a phase-shifted sinusoid of frequency f (Hz).\n",
    "f = 0.2\n",
    "for i in range(episode_length):\n",
    "  action = []\n",
    "  for j in range(env.action_size):\n",
    "    action.append(\n",
    "        jp.sin(\n",
    "            unvmap(state).data.time * 2 * jp.pi * f\n",
    "            + j * 2 * jp.pi / env.action_size\n",
    "        )\n",
    "    )\n",
    "  # Apply the same action to every environment in the batch.\n",
    "  action = jp.tile(jp.array(action), (num_envs, 1))\n",
    "  state = jit_step(state, action)\n",
    "  rollout.append(unvmap(state))\n",
    "\n",
    "frames = env.render(rollout, camera=\"fixed\", width=256, height=256)\n",
    "k = next(iter(rollout[0].obs))  # first observation key, ex: pixels/view_0\n",
    "obs = [r.obs[k][..., 0] for r in rollout]  # visualise first channel\n",
    "\n",
    "media.show_videos([frames, obs], fps=1.0 / env.dt, height=256)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "686f6663",
   "metadata": {
    "id": "686f6663"
   },
   "source": [
    "#### Train\n",
    "Training the policy takes about 4 minutes on a hosted T4 instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06333cd4",
   "metadata": {
    "id": "06333cd4",
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "from mujoco_playground.config import dm_control_suite_params\n",
    "\n",
    "# Load vision-specific PPO configuration tuned for CartpoleBalance\n",
    "ppo_params = dm_control_suite_params.brax_vision_ppo_config(env_name)\n",
    "ppo_params.episode_length = episode_length\n",
    "ppo_params.network_factory = ppo_networks_vision.make_ppo_networks_vision\n",
    "ppo_params['batch_size'] = 128\n",
    "ppo_params['num_envs'] = ppo_params['num_eval_envs'] = num_envs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f419702",
   "metadata": {
    "id": "2f419702",
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "x_data, y_data, y_dataerr = [], [], []\n",
    "times = [datetime.now()]\n",
    "\n",
    "\n",
    "def progress(num_steps, metrics):\n",
    "  clear_output(wait=True)\n",
    "\n",
    "  times.append(datetime.now())\n",
    "  x_data.append(num_steps)\n",
    "  y_data.append(metrics[\"eval/episode_reward\"])\n",
    "  y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
    "\n",
    "  plt.xlim([0, ppo_params[\"num_timesteps\"] * 1.25])\n",
    "  plt.ylim([0, 100])\n",
    "  plt.xlabel(\"# environment steps\")\n",
    "  plt.ylabel(\"reward per episode\")\n",
    "  plt.title(f\"y={y_data[-1]:.3f}\")\n",
    "  plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
    "\n",
    "  display(plt.gcf())\n",
    "\n",
    "\n",
    "train_fn = functools.partial(\n",
    "    ppo.train, **dict(ppo_params), progress_fn=progress\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fb88915",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 962
    },
    "id": "3fb88915",
    "outputId": "a8d7c74b-3fa2-4441-f2da-662b4d3162b4"
   },
   "outputs": [],
   "source": [
    "make_inference_fn, params, metrics = train_fn(environment=env)\n",
    "print(f\"time to jit: {times[1] - times[0]}\")\n",
    "print(f\"time to train: {times[-1] - times[1]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "076c1ca1",
   "metadata": {
    "id": "076c1ca1"
   },
   "source": [
    "#### Visualize the Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cd8dec1",
   "metadata": {
    "id": "0cd8dec1",
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "jit_reset = jax.jit(env.reset)\n",
    "jit_step = jax.jit(env.step)\n",
    "jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64532268",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 713
    },
    "id": "64532268",
    "outputId": "1d8c7315-93a5-4345-d888-3f9256966f81"
   },
   "outputs": [],
   "source": [
    "rng = jax.random.PRNGKey(42)\n",
    "rollout = []\n",
    "n_episodes = 1\n",
    "\n",
    "for _ in range(n_episodes):\n",
    "  key_rng = jax.random.split(rng, num_envs)\n",
    "  state = jit_reset(key_rng)\n",
    "  rollout.append(unvmap(state))\n",
    "  for i in range(episode_length):\n",
    "    act_rng, rng = jax.random.split(rng)\n",
    "    act_rng = jax.random.split(act_rng, num_envs)\n",
    "    ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
    "    state = jit_step(state, ctrl)\n",
    "    rollout.append(unvmap(state))\n",
    "\n",
    "render_every = 1\n",
    "frames = env.render(rollout[::render_every], camera=\"fixed\")\n",
    "rewards = [s.reward for s in rollout]\n",
    "media.show_video(frames, fps=1.0 / env.dt / render_every)\n",
    "\n",
    "plt.plot(np.convolve(rewards, np.ones(10) / 10, mode=\"valid\"))\n",
    "plt.xlabel(\"time step\")\n",
    "plt.ylabel(\"reward\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eddb3fc3",
   "metadata": {
    "id": "eddb3fc3"
   },
   "source": [
    "🙌 Stay tuned for more hosted instances with Madrona-MJX rendering support!"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "jupytext": {
   "cell_metadata_filter": "-all",
   "encoding": "# coding: utf-8",
   "executable": "/usr/bin/env python",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
